Comparing machine learning approaches to predict SEEG accuracy

Stereoencephalography (SEEG) is a technique used in drug-resistant epilepsy patients that may be a candidate for surgical resection of the epileptogenic zone. Multiple electrodes are placed using a so-called "frame based" stereotactic approach, in our case using the Leksell frame. In our previous paper "Methodology, outcome, safety and in vivo accuracy in traditional frame-based stereoelectroencephalography" by Van der Loo et al (2017) we reported on SEEG electrode implantation accuracy in a cohort of 71 patients who were operated between September 2008 and April 2016, in whom a total of 902 electrodes were implanted. Data for in vivo application accuracy analysis were available for 866 electrodes.

The goal of the current project is to use a public version of this dataset (without any personal identifiers) to predict electrode implantation accuracy by using and comparing different machine learning approaches.

Pieter Kubben, MD, PhD
neurosurgeon @ Maastricht University Medical Center, The Netherlands

For any questions you can reach me by email or on Twitter.

Data description

The public dataset contains these variables:

  • PatientPosition: patient position during surgery (nominal: supine, prone)
  • Contacts: nr of contacts of electrode implanted (ordinal: 5, 8, 10, 12, 15, 18)
  • ElectrodeType: describes trajectory type (nominal: oblique, orthogonal). Oblique refers to implantation using the Leksell arc, and orthogonal using a dedicated L-piece mounted on the frame (mostly implants in temporal lobe) when arc angles become too high (approx > 155°) or too low (approx < 25°)
  • PlanningX: planned Cartesian X coord of target (numeric, in mm)
  • PlanningY: planned Cartesian Y coord of target (numeric, in mm)
  • PlanningZ: planned Cartesian Z coord of target (numeric, in mm)
  • PlanningRing: planned ring coord, the trajectory direction in sagittal plane (numeric, in degrees); defines entry
  • PlanningArc: planned arc coord (trajectory direction in coronal plane (numeric, in degrees); defines entry
  • DuraTipDistancePlanned: distance from dura mater (outer sheet covering the brain surface) to target (numeric, in mm)
  • EntryX: real Cartesian X coord of entry point (numeric, in mm)
  • EntryY: real Cartesian Y coord of entry point (numeric, in mm)
  • EntryZ: real Cartesian Z coord of entry point (numeric, in mm)
  • TipX: real Cartesian X coord of target point (numeric, in mm)
  • TipY: real Cartesian Y coord of target point (numeric, in mm)
  • TipZ: real Cartesian Z coord of target point (numeric, in mm)
  • SkinSkullDistance: distance between skin surfacce and skull surface (numeric, in mm)
  • SkullThickness: skull thickness (numeric, in mm)
  • SkullAngle: insertion angle of electrode relative to skull (numeric, in degrees)
  • ScrewLength: length of bone screw used to guide and fixate electrode (ordinal: 20, 25, 30, 35 mm)

The electrodes are the Microdeep depth electrodes by DIXI Medical.

To the limited extent possible in this case I tried to make these FAIR data and adhere to FAIR guiding principles. In practice this meant I introduced the topic, described my data and created a DOI.

Now let's get started.


In [1]:
# import libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
%matplotlib inline
import warnings; warnings.simplefilter('ignore')
%xmode plain; # shorter error messages

# global setting whether to save figures or not
# will save as 300 dpi PNG - all filenames start with "fig_"
save_figures = False

In [2]:
# load data
electrodes = pd.read_csv('../data/electrodes_public.csv')

# find missing values
nan_rows = sum([True for idx,row in electrodes.iterrows() if any(row.isnull())])
print('Nr of rows with missing values:', nan_rows)


Nr of rows with missing values: 823

We will calculate target point localization error (TPLE) using the Euclidean distance and remove the entry data as we won't be using them (still wanted to share them in dataset though).


In [3]:
# calculate TPLE and remove entry data from dataframe
electrodes['TPLE'] = np.sqrt(np.square(electrodes['TipX'] - electrodes['PlanningX']) + 
                              np.square(electrodes['TipY'] - electrodes['PlanningY']) + 
                              np.square(electrodes['TipZ'] - electrodes['PlanningZ'])
                             ).round(1)
electrodes.drop(['EntryX', 'EntryY', 'EntryZ'], axis = 1, inplace = True)
electrodes.head()


Out[3]:
PatientPosition Contacts ElectrodeType PlanningX PlanningY PlanningZ PlanningRing PlanningArc DuraTipDistancePlanned TipX TipY TipZ SkinSkullDistance SkullThickness SkullAngle ScrewLength TPLE
0 Supine 18.0 Oblique 125.8 106.5 135.5 154.6 90.2 86.6 126.4 106.8 135.2 7.0 9.7 70.3 25.0 0.7
1 Supine 18.0 Oblique 130.6 131.0 136.4 155.4 96.6 105.9 134.7 132.9 136.0 9.4 7.4 63.8 30.0 4.5
2 Supine 10.0 Oblique 139.1 104.9 124.0 131.1 146.0 35.2 139.1 108.3 124.4 9.8 5.5 66.4 30.0 3.4
3 Supine 10.0 Oblique 137.7 112.2 115.0 96.2 137.1 35.9 136.4 115.2 115.0 8.4 6.5 66.3 25.0 3.3
4 Supine 12.0 Oblique 126.0 76.0 124.1 159.1 156.3 39.5 124.6 75.8 126.8 7.4 4.6 84.7 25.0 3.0

Now we will remove large outliers (difference between planned and real coord) in the Z-axis as electrode insertion length (depth) is influenced also by other factors (calculations regarding depth which could lead to either too superficial or too deep, but also possible malfixation of the screw cap which may cause loosening of the electrode and hence a more superficial position.. it won't migrate into the depth spontaneously). These are very limited numbers and would too much influence further analysis.


In [4]:
# check for outliers in Z axis
large_depth_error = electrodes[np.abs(electrodes['PlanningZ'] - electrodes['TipZ']) > 10]
print('Outliers in Z axis (> 10mm):\n\n', large_depth_error['TPLE'])

# remove outliers
electrodes.drop(large_depth_error.index, inplace = True)
print('\nNew dataframe shape:', electrodes.shape) # removed 6 rows


Outliers in Z axis (> 10mm):

 446    14.8
462    14.5
463    11.2
508    45.8
612    14.2
632    20.0
Name: TPLE, dtype: float64

New dataframe shape: (860, 17)

We need to structure our data properly for further analysis and convert the categorical variables (nominal, ordinal) to the category type.


In [5]:
# convert categorical columns to "category" dtype
catcols = ['PatientPosition', 'Contacts', 'ElectrodeType', 'ScrewLength']
for cat in catcols:
    electrodes[cat] = electrodes[cat].astype('category')

# confirm correct types for all columns now
electrodes.dtypes


Out[5]:
PatientPosition           category
Contacts                  category
ElectrodeType             category
PlanningX                  float64
PlanningY                  float64
PlanningZ                  float64
PlanningRing               float64
PlanningArc                float64
DuraTipDistancePlanned     float64
TipX                       float64
TipY                       float64
TipZ                       float64
SkinSkullDistance          float64
SkullThickness             float64
SkullAngle                 float64
ScrewLength               category
TPLE                       float64
dtype: object

Let's get a short description of our TPLE data.


In [6]:
# get summary data on TPLE
tple = electrodes['TPLE']
tple.describe().round(1)


Out[6]:
count    860.0
mean       3.4
std        2.3
min        0.2
25%        2.0
50%        2.9
75%        4.1
max       19.8
Name: TPLE, dtype: float64

We now have data in the right format, but for classification we need to bin the continuous outcome variables EPLE and TPLE into categories. Alternatively we could approach this as a regression problem, but given the relative limited amount of data classification should lead to a better prediction model that is still relevant for potential clinical use.

Let's create a new variable TPLE category for this purpose.


In [7]:
# create different possible cuts to create categories (and experiment with them)
tple_max = tple.max().round()
electrodes_3cat = pd.cut(tple, bins = [0, 2.5, 5, tple_max], labels = ['0 - 2.5', '2.5 - 5', '> 5'])
electrodes_4cat = pd.cut(tple, bins = [0, 2, 4, 6, tple_max], labels = ['0 - 2', '2 - 4', '4 - 6', '> 6'])
electrodes_4quant = pd.cut(tple, bins = [0, tple.quantile(.25), tple.median(), tple.quantile(.75), tple_max], labels = ['0 - 2', '2 - 3', '3 - 4', '> 4'])
electrodes_5cat = pd.cut(tple, bins = [0, 1, 2, 5, 10, tple_max], labels = ['0 - 1', '1 - 2', '2 - 5', '5 - 10', '> 10'])
electrodes_7cat = pd.cut(tple, bins = [0, 1, 2, 4, 6, 8, 10, tple_max], labels = ['0 - 1', '1 - 2', '2 - 4', '4 - 6', '6 - 8', '8 - 10', '> 10'])

# apply cut to create TPLE category column
electrodes['TPLE category'] = electrodes_5cat
nr_of_y_categories = len(electrodes['TPLE category'].unique()) # needed for confusion matrix and keras MLP

# check correct conversion with first 15 rows
electrodes[['TPLE', 'TPLE category']].head(15).T


Out[7]:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
TPLE 0.7 4.5 3.4 3.3 3 1.8 2.3 4.2 4.8 3.4 3.9 3.6 7.4 4.2 4.4
TPLE category 0 - 1 2 - 5 2 - 5 2 - 5 2 - 5 1 - 2 2 - 5 2 - 5 2 - 5 2 - 5 2 - 5 2 - 5 5 - 10 2 - 5 2 - 5

In [8]:
# count nr of items in each category
electrodes[['TPLE', 'TPLE category']].groupby('TPLE category').count().T


Out[8]:
TPLE category 0 - 1 1 - 2 2 - 5 5 - 10 > 10
TPLE 50 184 492 112 22

So we have a lof of electrodes in the 2-5mm TPLE category, which is explained well with an interquartile range of 2.0 - 4.1. Let's create some plots to learn more about our data.

Visualization

Let's first get an impression of variable distributions.


In [9]:
# plot TPLE distribution
tple.plot(kind = 'density', figsize = (15,5), title = 'TPLE density plot');
plt.xlim(0,10);
if save_figures:
    plt.savefig('fig_tple_density_plot.png', dpi = 300)



In [10]:
# plot variable distributions (density)
params = {'subplots': True, 'layout': (3,4), 'sharex': False, 'figsize': (16, 12)}
elec_visual = electrodes.drop(['TPLE'], axis = 1)
elec_visual.plot(kind='density', **params);
if save_figures:
    plt.savefig('fig_variables_density_plot.png', dpi = 300)



In [11]:
# alternatively, make a box plot
elec_visual.plot(kind='box', sharey = False, **params);
if save_figures:
    plt.savefig('fig_variables_boxplot.png', dpi = 300)


Now we have a visual impression of how variables are distributed. As we want to develop a model to predict TPLE category, it would be helpful to see to what extent variable distribution differs per TPLE category.


In [12]:
# first relate numeric variables to TPLE category
numcols = list(electrodes.columns[electrodes.dtypes == float].drop('TPLE'))
fig, axes = plt.subplots(4,3, figsize = (15,20))
fig.subplots_adjust(hspace=.5)
electrodes.boxplot(column = numcols, by = 'TPLE category', ax = axes);
if save_figures:
    plt.savefig('fig_tplecompare_boxplot.png', dpi = 300)


In contrast to numerical variables in which we can plot categories against actual feature values, in categorical columns we can "only" plot categories against the number of items in each category. This is what we will do next.

Note that these are absolute counts, and no ratios (e.g. supine position is used most and this reflects also below). So the point is to inspect the graphs for any remarkable ratio differences between categories and not simply look at the highest bars...


In [13]:
# show absolute TPLE count per category 
elec_count = 'Electrode count'
for cat in catcols:
    cat_count = electrodes[[cat, 'TPLE','TPLE category']].groupby([cat,'TPLE category']).count()
    cat_count.rename_axis({'TPLE': elec_count}, axis = 'columns', inplace = True)
    sns.factorplot(x = cat, y = elec_count, data = cat_count.reset_index(), kind = 'bar', hue = 'TPLE category',
                   size = 5, aspect = 3);
    if save_figures:
        plt.savefig('fig_tplecompare_cat_{}.png'.format(cat), dpi = 300)


Now, how do they correlate with each other? To be able to get correlations we first have to deal with missing data. We will create a temporary soluting here for visualisation and use another (for that purpose recommended) approach for predictive modelling.


In [14]:
# copy dataframe to temporary frame in which we'll impute missing values
elec_pairplot = electrodes

# create a utility function to check if we (still) have missing values present in our data
def print_missing(dataframe):
    '''checks which columns contain missing values and returns count'''
    df_na = dataframe.columns[electrodes.isnull().any()]
    
    if len(df_na) > 0:
        print('Missing data present in {} columns: \n'.format(len(df_na)))
        for c in df_na:
            print('- {} ({})'.format(c, dataframe[c].isnull().sum()))
    else:
        print('No missing data found! :-)')
        
print_missing(elec_pairplot)


Missing data present in 5 columns: 

- Contacts (56)
- PlanningRing (269)
- PlanningArc (269)
- DuraTipDistancePlanned (54)
- ScrewLength (534)

In [15]:
# use median values for numeric columns
for missing in ['PlanningRing', 'PlanningArc', 'DuraTipDistancePlanned']:
    elec_pairplot[missing] = elec_pairplot[missing].fillna(elec_pairplot[missing].median())

# use most frequent value for categorical columns
for missing in ['Contacts', 'ScrewLength']:
    most_frequent_value = elec_pairplot[missing].value_counts().index[0]
    print('Most frequent value in category "{}": {}'.format(missing, most_frequent_value))
    elec_pairplot[missing] = elec_pairplot[missing].fillna(most_frequent_value)

print_missing(elec_pairplot)


Most frequent value in category "Contacts": 18.0
Most frequent value in category "ScrewLength": 25.0
No missing data found! :-)

From practical experience I can confirm that we do use 18 contact points frequently often and it makes sense to use those here (only 56 missing... more advanced may be to correlate with other variables and decide then). Regarding screw length, 25mm is used by far the most, so even despite the fact that most values are missing (534 / 860) I still kept them in (in the original paper I referred to, increasing screw length seems to correspond with increasing TPLE).

Now let's look at some correlations.


In [16]:
# pairplot to correlate variable distributions per category.. busy plot
sns.pairplot(elec_pairplot[['PatientPosition', 'Contacts', 'ElectrodeType', 'DuraTipDistancePlanned', 
                            'SkinSkullDistance', 'SkullThickness', 'SkullAngle', 'ScrewLength', 'TPLE category']], 
             hue = 'TPLE category');
if save_figures:
    plt.savefig('fig_electrodes_corr_pairplot.png', dpi = 300)


Those are not exactly easily separable clusters... looks bad for further predictive modelling.... :-(

Alternatively we will use Peason's correlations to create a "heatmap". To include categorical variables we need to "one hot encode" them first (using pd.get_dummies()).


In [17]:
plt.figure(figsize = (15,12));
electrodes_dummies = pd.get_dummies(electrodes.drop(['TPLE category'], axis = 1))

# correlations for continuous variables only
# sns.heatmap(electrodes.corr(), square = True, annot = True, linewidths = .5, cmap = 'RdBu_r');

# correlations including categorical variables
sns.heatmap(electrodes_dummies.corr(), square = True, annot = False, cmap = 'RdBu_r');

if save_figures:
    plt.tight_layout() # needed to prevent cutting X-axis labels
    plt.savefig('fig_electrodes_corr_heatmap.png', dpi = 300)


TPLE classification

We will now use several machine learning classification approaches (manual, auto ML and a quick glance at deep learning) to predict TPLE category. I tried several feature selection and dimensionality reduction approaches that I left in place as comments as they did not contribute to better results.

Manual approach


In [30]:
from sklearn.preprocessing import LabelEncoder, PolynomialFeatures, Imputer, MinMaxScaler
from sklearn.feature_selection import SelectKBest, chi2, VarianceThreshold
elec_features = electrodes.drop(['TPLE', 'TPLE category'], axis = 1)

# encode categorical features (no one hot encoding to avoid creating too much features)
for cat in catcols:
    elec_features[cat] = LabelEncoder().fit_transform(elec_features[cat])

X = elec_features # to start with
# deal with missing values the sklearn way (and do not impute if not needed to)
X = Imputer(strategy = 'most_frequent').fit_transform(X)

# feature selection - none of these actually improved accuracy (not tried for all other outcomes)
# X = elec_features[['ScrewLength', 'PatientPosition', 'ElectrodeType', 'Contacts', 'SkullAngle']] 
# X = elec_features['SkullAngle'][:, np.newaxis]
# X = MinMaxScaler().fit_transform(X) # is it correct to do this on whole dataset?
# X = VarianceThreshold(threshold=(.8 * (1 - .8))).fit_transform(X)
# X = PolynomialFeatures(2, interaction_only = True).fit_transform(X)
# X = SelectKBest(chi2, k = 5).fit_transform(X, y)

# encode outcome (separate encoder, want to make sure to be able to reverse later)
le = LabelEncoder()
y = le.fit_transform(electrodes['TPLE category'])

In [31]:
# dimensionality reduction using PCA
from sklearn.decomposition import PCA
pca = PCA(n_components = X.shape[1])
X_pca = pca.fit_transform(X)

# plot PCA variance ratio
plt.figure(figsize = (10,6));
plt.plot(pca.explained_variance_ratio_.cumsum());
plt.title('Explained variance ratio by PCA components');
plt.xlabel('PCA component');
plt.ylabel('Cumulative variance ratio');
if save_figures:
    plt.savefig('fig_mlmanual_pca.png', dpi = 300)



In [32]:
# use 5 principal components... (explains > 95% of variance) even slightly reduces accuracy...
# X = PCA(5).fit_transform(X)

Now split the data intro train and test set (maybe I should have done this earlier, e.g. using a Pipeline, for future predictions, but for now I'll leave it as is).


In [33]:
# split data
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 33, stratify = y)

In [39]:
# import modules
from sklearn.feature_selection import RFE
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC, LinearSVC
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from xgboost import XGBClassifier

# prepare models
models = {'LOG': LogisticRegression(),
          'LDA': LinearDiscriminantAnalysis(),
          'KNN': KNeighborsClassifier(),
          'CART': DecisionTreeClassifier(),
          'NB': GaussianNB(),
          'SVM': SVC(),
          'LSVM1': LinearSVC(penalty='l1', dual = False),
          # 'LSVM2': LinearSVC(),
          'RF': RandomForestClassifier(),
          'ADA': AdaBoostClassifier(),
          'XGB': XGBClassifier()
         }

Everything is set up. We will now evalute our models for accuracy and plot the results.


In [40]:
# evaluate each model in turn
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report
names = []; results = []
print('MODEL: ACCURACY (STD)\n')

for name, model in models.items():
    cv_results = cross_val_score(model, X_train, y_train, cv = KFold(10), scoring = 'accuracy')
    names.append(name)
    results.append(cv_results)
    print('{}: {:.2f} ({:.2f})'.format(name, cv_results.mean(), cv_results.std()))
    # print(classification_report(y_test, model.fit(X_train, y_train).predict(X_test), target_names = le.classes_))


MODEL: ACCURACY (STD)

LOG: 0.57 (0.03)
LDA: 0.57 (0.04)
KNN: 0.47 (0.06)
CART: 0.42 (0.05)
NB: 0.48 (0.08)
SVM: 0.57 (0.04)
LSVM1: 0.57 (0.04)
RF: 0.52 (0.05)
ADA: 0.44 (0.08)
XGB: 0.57 (0.04)

In [43]:
# boxplot algorithm comparison
fig = plt.figure(figsize = (12,6))
plt.title('Accuracy comparison for different models')
ax = fig.add_subplot(111)
ax.set_ylim(0.3, 0.7)
plt.boxplot(results)
ax.set_xticklabels(names); # set xtick labels after plotting! (otherwise default values override custom labels)
if save_figures:
    plt.savefig('fig_mlmanual_boxplot.png', dpi = 300)



In [84]:
# create confusion matrix to compare true labels and predicted labels
from sklearn.metrics import confusion_matrix
model = LogisticRegression().fit(X_train, y_train)
y_pred = model.predict(X_test)
conf_mat_labels = sorted(list(electrodes['TPLE category'].unique()))
conf_mat = confusion_matrix(y_test, y_pred).T
# conf_mat = confusion_matrix(y_test, y_pred, labels = conf_mat_labels)

plt.figure(figsize = (7, 7));
plt.title('Confusion matrix for model predictions')
sns.heatmap(conf_mat, annot = True, fmt = 'd', cbar = False, square = True, cmap = 'Purples', linewidth = .5,
            xticklabels = conf_mat_labels, yticklabels = conf_mat_labels);
plt.xlabel('True TPLE category');
plt.ylabel('Predicted TPLE category');
if save_figures:
    plt.savefig('fig_mlmanual_confusionmatrix.png', dpi = 300)


Hmm... TPLE category 2-5 is overrepresented... how many variables of our test set do we actually have in each category?


In [85]:
print('Nr of values in each category for y_test:\n')
y_test_unique = np.unique(le.inverse_transform(y_test), return_counts=True)
print('Unique values:', y_test_unique[0])
print('Unique values count:', y_test_unique[1])


Nr of values in each category for y_test:

Unique values: ['0 - 1' '1 - 2' '2 - 5' '5 - 10' '> 10']
Unique values count: [ 15  55 148  33   7]

I am starting to doubt whether accuracy is the best outcome metric. Let's get some other metrics too for the model used for the confusion matrix above.


In [96]:
from sklearn.metrics import accuracy_score, classification_report
print('Evaluation for:\n\n', model)

# accuracy score
print('\n\nAccuracy score:', accuracy_score(y_test, y_pred).round(2))
print('\n\nClassification report: \n\n', classification_report(y_test, y_pred, target_names = conf_mat_labels))

# cross val score
cvscore = cross_val_score(model, X_test, y_test, scoring = 'accuracy', cv = 10)
print('\n\nCross val score (std): {:.2f} ({:.2f})'.format(cvscore.mean(), cvscore.std())) # why different?


Evaluation for:

 LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)


Accuracy score: 0.59


Classification report: 

              precision    recall  f1-score   support

      0 - 1       0.00      0.00      0.00        15
      1 - 2       0.43      0.05      0.10        55
      2 - 5       0.59      0.95      0.73       148
     5 - 10       0.57      0.24      0.34        33
       > 10       0.00      0.00      0.00         7

avg / total       0.51      0.59      0.48       258



Cross val score (std): 0.53 (0.07)

Let's say there is room for improvement.... hyperparameter tuning may help, we will do so using a cross-validated parameter grid search for SVM and XGB here.


In [28]:
# SVC tuning
from sklearn.model_selection import GridSearchCV
# SVC().get_params()
svm_param_grid = [
  {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
  {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
 ]
svm_param_grid = {'kernel': ['linear', 'rbf'], 'C': [1, 10]} # simple and less slow... 
svm_grid = GridSearchCV(SVC(), svm_param_grid, cv = 3, scoring = 'accuracy', n_jobs = -1)
%time svm_grid.fit(X_train, y_train)


CPU times: user 1min 4s, sys: 64.2 ms, total: 1min 4s
Wall time: 2min 20s
Out[28]:
GridSearchCV(cv=3, error_score='raise',
       estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False),
       fit_params={}, iid=True, n_jobs=-1,
       param_grid={'kernel': ['linear', 'rbf'], 'C': [1, 10]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
       scoring='accuracy', verbose=0)

In [29]:
print('Best SVM params:', svm_grid.best_params_)
print('Best SVM score:', svm_grid.best_score_.round(2))
print('Best SVM model:\n', svm_grid.best_estimator_)


Best SVM params: {'C': 10, 'kernel': 'linear'}
Best SVM score: 0.58
Best SVM model:
 SVC(C=10, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='linear',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

In [30]:
# XGB tuning
# XGBClassifier().get_params()
xgb_param_grid = {'learning_rate': [0.001, 0.01, 0.1, 1.0],
                  'max_depth': [3, 5, 7],
                  'n_estimators': [100, 150, 200]
                 }
xgb_grid = GridSearchCV(XGBClassifier(), xgb_param_grid, cv = 3, scoring = 'accuracy', n_jobs = -1)
%time xgb_grid.fit(X_train, y_train)


CPU times: user 499 ms, sys: 50.7 ms, total: 550 ms
Wall time: 7.08 s
Out[30]:
GridSearchCV(cv=3, error_score='raise',
       estimator=XGBClassifier(base_score=0.5, colsample_bylevel=1, colsample_bytree=1,
       gamma=0, learning_rate=0.1, max_delta_step=0, max_depth=3,
       min_child_weight=1, missing=None, n_estimators=100, nthread=-1,
       objective='binary:logistic', reg_alpha=0, reg_lambda=1,
       scale_pos_weight=1, seed=0, silent=True, subsample=1),
       fit_params={}, iid=True, n_jobs=-1,
       param_grid={'learning_rate': [0.001, 0.01, 0.1, 1.0], 'max_depth': [3, 5, 7], 'n_estimators': [100, 150, 200]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
       scoring='accuracy', verbose=0)

In [31]:
print('Best XGB params:', xgb_grid.best_params_)
print('Best XGB score:', xgb_grid.best_score_.round(2))
print('Best XGB model:\n', xgb_grid.best_estimator_)


Best XGB params: {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 150}
Best XGB score: 0.58
Best XGB model:
 XGBClassifier(base_score=0.5, colsample_bylevel=1, colsample_bytree=1,
       gamma=0, learning_rate=0.01, max_delta_step=0, max_depth=3,
       min_child_weight=1, missing=None, n_estimators=150, nthread=-1,
       objective='multi:softprob', reg_alpha=0, reg_lambda=1,
       scale_pos_weight=1, seed=0, silent=True, subsample=1)

In [32]:
# plot XGB tree
from xgboost import plot_tree
fig, axes = plt.subplots(figsize = (15,7));
plot_tree(xgb_grid.best_estimator_, ax = axes, rankdir='LR');
if save_figures:
    plt.savefig('fig_mlmanual__xgbtree.png', dpi = 300); # will still look like shit


AutoML

Let's try an automated ML approach using TPOT by Randal Olson for this purpose... this can take a while (12-24h).

A short test can be run using TPOTClassifier(generations=5, population_size=50, verbosity=2).


In [33]:
# from tpot import TPOTClassifier
# tpot = TPOTClassifier(generations=50, population_size=50, verbosity=2, n_jobs = -1)
# tpot.fit(X_train, y_train)
# print('\nTPOT score: ', tpot.score(X_test, y_test))
# tpot.export('tpot_tple_classification.py')


Optimization Progress:   4%|▍         | 96/2550 [00:32<3:17:21,  4.83s/pipeline]
Generation 1 - Current best internal CV score: 0.5980249334640707
Optimization Progress:   6%|▌         | 145/2550 [00:53<1:27:48,  2.19s/pipeline]
Generation 2 - Current best internal CV score: 0.602983611150021
Optimization Progress:   8%|▊         | 193/2550 [01:10<56:09,  1.43s/pipeline]  
Generation 3 - Current best internal CV score: 0.602983611150021
Optimization Progress:   9%|▉         | 241/2550 [01:31<30:22,  1.27pipeline/s]
Generation 4 - Current best internal CV score: 0.602983611150021
Optimization Progress:  11%|█▏        | 291/2550 [02:17<57:36,  1.53s/pipeline]  
Generation 5 - Current best internal CV score: 0.6061213055049727
Optimization Progress:  13%|█▎        | 336/2550 [02:50<53:16,  1.44s/pipeline]  
Generation 6 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  15%|█▍        | 377/2550 [03:07<28:26,  1.27pipeline/s]
Generation 7 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  17%|█▋        | 424/2550 [03:31<43:16,  1.22s/pipeline]  
Generation 8 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  18%|█▊        | 471/2550 [03:49<28:42,  1.21pipeline/s]
Generation 9 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  20%|██        | 519/2550 [04:09<28:45,  1.18pipeline/s]
Generation 10 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  22%|██▏       | 566/2550 [04:37<23:30,  1.41pipeline/s]
Generation 11 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  24%|██▍       | 614/2550 [04:57<26:02,  1.24pipeline/s]
Generation 12 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  26%|██▌       | 661/2550 [05:30<29:48,  1.06pipeline/s]
Generation 13 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  28%|██▊       | 708/2550 [05:52<23:22,  1.31pipeline/s]
Generation 14 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  30%|██▉       | 755/2550 [06:12<1:05:58,  2.21s/pipeline]
Generation 15 - Current best internal CV score: 0.6197226502311248
Optimization Progress:  31%|███▏      | 800/2550 [06:40<1:18:20,  2.69s/pipeline]
Generation 16 - Current best internal CV score: 0.629555960218518
Optimization Progress:  33%|███▎      | 848/2550 [07:16<47:04,  1.66s/pipeline]  
Generation 17 - Current best internal CV score: 0.629555960218518
Optimization Progress:  35%|███▌      | 896/2550 [07:50<23:38,  1.17pipeline/s]
Generation 18 - Current best internal CV score: 0.6329457907269926
Optimization Progress:  37%|███▋      | 943/2550 [08:20<40:03,  1.50s/pipeline]  
Generation 19 - Current best internal CV score: 0.6329457907269926
Optimization Progress:  39%|███▉      | 991/2550 [08:43<40:34,  1.56s/pipeline]
Generation 20 - Current best internal CV score: 0.6344726152122147
Optimization Progress:  41%|████      | 1036/2550 [09:11<19:43,  1.28pipeline/s]
Generation 21 - Current best internal CV score: 0.6445580613531308
Optimization Progress:  42%|████▏     | 1083/2550 [09:35<53:53,  2.20s/pipeline]  
Generation 22 - Current best internal CV score: 0.6446841294298922
Optimization Progress:  44%|████▍     | 1131/2550 [10:10<36:40,  1.55s/pipeline]  
Generation 23 - Current best internal CV score: 0.6529906149320633
Optimization Progress:  46%|████▌     | 1175/2550 [10:44<31:20,  1.37s/pipeline]
Generation 24 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  48%|████▊     | 1220/2550 [11:09<50:46,  2.29s/pipeline]  
Generation 25 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  50%|████▉     | 1264/2550 [11:26<27:43,  1.29s/pipeline]
Generation 26 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  51%|█████▏    | 1310/2550 [12:03<18:12,  1.13pipeline/s]
Generation 27 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  53%|█████▎    | 1360/2550 [12:25<1:07:01,  3.38s/pipeline]
Generation 28 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  55%|█████▌    | 1406/2550 [13:02<38:28,  2.02s/pipeline]  
Generation 29 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  57%|█████▋    | 1454/2550 [13:20<21:21,  1.17s/pipeline]
Generation 30 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  59%|█████▉    | 1501/2550 [13:43<14:04,  1.24pipeline/s]
Generation 31 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  60%|██████    | 1541/2550 [14:11<11:16,  1.49pipeline/s]
Generation 32 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  62%|██████▏   | 1588/2550 [14:28<08:21,  1.92pipeline/s]
Generation 33 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  64%|██████▍   | 1634/2550 [14:58<08:17,  1.84pipeline/s]
Generation 34 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  66%|██████▌   | 1681/2550 [15:14<06:02,  2.40pipeline/s]
Generation 35 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  68%|██████▊   | 1727/2550 [15:39<07:14,  1.89pipeline/s]
Generation 36 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  69%|██████▉   | 1772/2550 [16:08<06:23,  2.03pipeline/s]
Generation 37 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  71%|███████   | 1816/2550 [16:35<06:42,  1.82pipeline/s]
Generation 38 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  73%|███████▎  | 1859/2550 [16:55<05:54,  1.95pipeline/s]
Generation 39 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  75%|███████▍  | 1905/2550 [17:33<07:34,  1.42pipeline/s]
Generation 40 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  77%|███████▋  | 1951/2550 [18:06<06:42,  1.49pipeline/s]
Generation 41 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  78%|███████▊  | 1994/2550 [18:27<05:05,  1.82pipeline/s]
Generation 42 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  80%|███████▉  | 2036/2550 [19:08<06:54,  1.24pipeline/s]
Generation 43 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  82%|████████▏ | 2084/2550 [19:29<05:04,  1.53pipeline/s]
Generation 44 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  84%|████████▎ | 2130/2550 [20:02<04:36,  1.52pipeline/s]
Generation 45 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  85%|████████▌ | 2178/2550 [20:31<05:24,  1.15pipeline/s]
Generation 46 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  87%|████████▋ | 2222/2550 [21:04<04:17,  1.27pipeline/s]
Generation 47 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  89%|████████▉ | 2264/2550 [21:31<03:17,  1.45pipeline/s]
Generation 48 - Current best internal CV score: 0.6646449082504552
Optimization Progress:  91%|█████████ | 2310/2550 [21:51<02:25,  1.65pipeline/s]
Generation 49 - Current best internal CV score: 0.6646449082504552

Generation 50 - Current best internal CV score: 0.6646449082504552

Best pipeline: RandomForestClassifier(FastICA(VarianceThreshold(input_matrix, VarianceThreshold__threshold=0.4), FastICA__tol=0.1), RandomForestClassifier__bootstrap=False, RandomForestClassifier__criterion=gini, RandomForestClassifier__max_features=0.3, RandomForestClassifier__min_samples_leaf=1, RandomForestClassifier__min_samples_split=10, RandomForestClassifier__n_estimators=100)

TPOT score:  0.643410852713

The exported file tpot_tple_classification.py contains the full pipeline from the fitted model. We'll import it below so we can manually use it without having to re-run the full TPOT process.


In [54]:
# construct TPOT fitted pipeline as model
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import FastICA
tpot_model = make_pipeline(
    VarianceThreshold(threshold=0.4),
    FastICA(tol=0.1),
    RandomForestClassifier(bootstrap=False, criterion="gini", max_features=0.3, min_samples_leaf=1, min_samples_split=10, n_estimators=100)
).fit(X_train, y_train)
y_pred_tpot = tpot_model.predict(X_test)

In [61]:
plt.figure(figsize = (7, 7));
plt.title('Confusion matrix for TPOT predictions')
sns.heatmap(confusion_matrix(y_test, tpot_model.predict(X_test)).T, cbar = False, annot = True, fmt = 'd',
            square = True, cmap = 'Oranges', linewidth = .5, 
            xticklabels = conf_mat_labels, yticklabels = conf_mat_labels);
plt.xlabel('True TPLE category');
plt.ylabel('Predicted TPLE category');
if save_figures:
    plt.savefig('fig_tpot_confusionmatrix.png', dpi = 300)



In [95]:
print('Evaluation for fitted TPOT pipeline:\n\n', tpot_model)

# accuracy score
tpot_accscore = accuracy_score(y_test, y_pred_tpot)
print('\n\nAccuracy score: {:.2f}'.format(tpot_accscore))
print('\n\nClassification report: \n\n', classification_report(y_test, y_pred_tpot, target_names = conf_mat_labels))

# cross val score
tpot_cvscore = cross_val_score(tpot_model, X_test, y_test, scoring = 'accuracy', cv = 10)
print('\n\nCross val score (std): {:.2f} ({:.2f})'.format(tpot_cvscore.mean(), tpot_cvscore.std()))


Evaluation for fitted TPOT pipeline:

 Pipeline(steps=[('variancethreshold', VarianceThreshold(threshold=0.4)), ('fastica', FastICA(algorithm='parallel', fun='logcosh', fun_args=None, max_iter=200,
    n_components=None, random_state=None, tol=0.1, w_init=None,
    whiten=True)), ('randomforestclassifier', RandomForestClassifier(bootstrap=False, ...mators=100, n_jobs=1, oob_score=False, random_state=None,
            verbose=0, warm_start=False))])


Accuracy score: 0.69


Classification report: 

              precision    recall  f1-score   support

      0 - 1       0.00      0.00      0.00        15
      1 - 2       0.57      0.47      0.51        55
      2 - 5       0.72      0.91      0.81       148
     5 - 10       0.68      0.45      0.55        33
       > 10       1.00      0.14      0.25         7

avg / total       0.65      0.69      0.65       258



Cross val score (std): 0.59 (0.08)

Deep learning

As a quick glance to deep learning we will apply a Multilayer Perceptron (MLP) for multi-class softmax classification using stochastic gradient descent as an optimizer. The code is borrowed from the official keras Sequential model guide and adapted where needed.


In [59]:
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.optimizers import SGD
from keras.utils.np_utils import to_categorical


Using TensorFlow backend.

In [64]:
keras_model = Sequential()
keras_model.add(Dense(64, activation='relu', input_dim=len(elec_features.columns)))
keras_model.add(Dropout(0.5))
keras_model.add(Dense(64, activation='relu'))
keras_model.add(Dropout(0.5))
keras_model.add(Dense(nr_of_y_categories, activation='softmax'))

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
keras_model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
keras_model.fit(X_train, to_categorical(y_train), epochs=100, batch_size=10)
y_pred_keras = keras_model.predict(X_test)

scores = keras_model.evaluate(X_test, to_categorical(y_test), batch_size=128)
print('\n\n{}: {:.2f}'.format(keras_model.metrics_names[1], scores[1]))


Epoch 1/100
602/602 [==============================] - 0s - loss: 7.4488 - acc: 0.5365      
Epoch 2/100
602/602 [==============================] - 0s - loss: 7.8016 - acc: 0.5150     
Epoch 3/100
602/602 [==============================] - 0s - loss: 7.1487 - acc: 0.5565     
Epoch 4/100
602/602 [==============================] - 0s - loss: 7.1219 - acc: 0.5581     
Epoch 5/100
602/602 [==============================] - 0s - loss: 7.2558 - acc: 0.5498     
Epoch 6/100
602/602 [==============================] - 0s - loss: 7.0149 - acc: 0.5648     
Epoch 7/100
602/602 [==============================] - 0s - loss: 7.1956 - acc: 0.5532     
Epoch 8/100
602/602 [==============================] - 0s - loss: 7.3629 - acc: 0.5432     
Epoch 9/100
602/602 [==============================] - 0s - loss: 7.1487 - acc: 0.5565     
Epoch 10/100
602/602 [==============================] - 0s - loss: 7.4432 - acc: 0.5382     
Epoch 11/100
602/602 [==============================] - 0s - loss: 7.2826 - acc: 0.5482     
Epoch 12/100
602/602 [==============================] - 0s - loss: 7.3361 - acc: 0.5449     
Epoch 13/100
602/602 [==============================] - 0s - loss: 7.2558 - acc: 0.5498     
Epoch 14/100
602/602 [==============================] - 0s - loss: 7.2023 - acc: 0.5532     
Epoch 15/100
602/602 [==============================] - 0s - loss: 7.2826 - acc: 0.5482     
Epoch 16/100
602/602 [==============================] - 0s - loss: 7.2826 - acc: 0.5482     
Epoch 17/100
602/602 [==============================] - 0s - loss: 6.8919 - acc: 0.5698     
Epoch 18/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 19/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 20/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 21/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 22/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 23/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714        
Epoch 24/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 25/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 26/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 27/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 28/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 29/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 30/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 31/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 32/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 33/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 34/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 35/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 36/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 37/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 38/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 39/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 40/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 41/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 42/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 43/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 44/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 45/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 46/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 47/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 48/100
602/602 [==============================] - 0s - loss: 6.9079 - acc: 0.5714     
Epoch 49/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 50/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 51/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 52/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 53/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 54/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 55/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 56/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 57/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 58/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 59/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 60/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 61/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 62/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 63/100
602/602 [==============================] - 0s - loss: 6.9100 - acc: 0.5714     
Epoch 64/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 65/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 66/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 67/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 68/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 69/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 70/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 71/100
602/602 [==============================] - 0s - loss: 6.8947 - acc: 0.5714     
Epoch 72/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 73/100
602/602 [==============================] - 0s - loss: 6.8836 - acc: 0.5714     
Epoch 74/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 75/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 76/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 77/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 78/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 79/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 80/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 81/100
602/602 [==============================] - 0s - loss: 6.9345 - acc: 0.5698     
Epoch 82/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 83/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 84/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 85/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 86/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 87/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 88/100
602/602 [==============================] - 0s - loss: 6.9613 - acc: 0.5681     
Epoch 89/100
602/602 [==============================] - 0s - loss: 6.8953 - acc: 0.5714     
Epoch 90/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 91/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 92/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 93/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 94/100
602/602 [==============================] - 0s - loss: 6.9105 - acc: 0.5698     
Epoch 95/100
602/602 [==============================] - 0s - loss: 6.9371 - acc: 0.5698     
Epoch 96/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 97/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 98/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 99/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
Epoch 100/100
602/602 [==============================] - 0s - loss: 6.9078 - acc: 0.5714     
128/258 [=============>................] - ETA: 0s

acc: 0.57

In [36]:
from keras.utils import plot_model 
if save_figures:
    plot_model(model, show_shapes = True, to_file = 'fig_deeplearning_keras_mlp.png')

Conclusion

We compared 3 different approaches in machine learning (manual, autoML and deep learning) to predict SEEG electrode implantation accuracy. Although I consider the results not bad for a start, there is definitely more work to do (on more data!) to reach a point in which a model like this could be implemented in e.g. a planning system to assist the surgeon in avoiding trajectories that are predicted to have a high deviation.